#    Copyright (C) 2018  Boris Bobrov
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <https://www.gnu.org/licenses/>.

import json
import logging

import daiquiri

# Configure daiquiri-based logging at INFO level as an import-time side
# effect, and create the logger used throughout this module.
daiquiri.setup(level=logging.INFO)
logger = daiquiri.getLogger(__name__)

# URL template of the firewall control page. The '{name}' and
# '{typeofserver}' placeholders are filled in with the server's
# 'servername' and 'typeofserver' attributes via str.format().
FIREWALL_URL = ('https://www.nfoservers.com/control/firewall.pl?'
                'name={name}&typeofserver={typeofserver}')
# Default number of rule slots; used as the default 'fields_num' of both
# FirewallRules and Firewall.
DEFAULT_FIELDS_NUM = 25


class FirewallIsFullError(Exception):
    '''Intended to signal that no free rule slot is left in the firewall.'''

class NoMatchingDescriptionError(Exception):
    '''Signals that no rule carries the requested description.'''


class FirewallRules:
    '''Ordered collection of firewall rules, one dict per rule slot.

    Every rule is a dict of form-field values. The methods of this class
    read the keys 'description', 'firewall_numberoption',
    'firewall_easychoice' and 'firewall_filterid'.
    '''

    def __init__(self, rules, fields_num=DEFAULT_FIELDS_NUM):
        '''Store the rules and the expected number of rule slots.

        rules -- list of rule dicts; it is kept by reference and modified
                 in place by prepare_new_rule_idx().
        fields_num -- number of rule slots the sanity checks expect.
        '''
        self.rules = rules
        self.fields_num = fields_num

    def get_idx_by_description(self, description):
        '''Return the list index of the first rule with this description.

        Raises NoMatchingDescriptionError when no rule matches and KeyError
        when a rule inspected before a match has no 'description' key.
        '''
        for idx, rule in enumerate(self.rules):
            if rule['description'] == description:
                return idx
        raise NoMatchingDescriptionError('No such description')

    def sanity_check(self):
        '''Verify the rule count and the numbering of the rules.

        Raises AssertionError when either check fails.
        '''
        self._check_number()
        self._check_numberoption()

    def _check_numberoption(self):
        '''Ensure the numberoptions are exactly 1..fields_num, each once.'''
        numberoptions = sorted(
            int(rule['firewall_numberoption']) for rule in self.rules)
        if numberoptions != list(range(1, self.fields_num + 1)):
            raise AssertionError('Something is wrong with numberoptions')

    def _check_number(self):
        '''Ensure there are exactly fields_num rules.'''
        # rules should be deleted by setting 'firewall_easychoice' to 'choose'
        if len(self.rules) != self.fields_num:
            raise AssertionError('Not enough rules')

    def prepare_new_rule_idx(self, number):
        '''Safely prepare new rule.

        Creating rules is hard. A number for it needs to be picked and the
        number of the next rule needs to be bumped. This method does that.
        For convenience and due to implementation details of the firewall page,
        we will always edit the last rule in the python list. If the last rule
        is already occupied, the method will raise an error.

        number -- the numberoption (integer) the new rule should get.

        Returns index of the rule to edit, which is always the last one.

        Raises FirewallIsFullError when the last rule is occupied, i.e. its
        'firewall_easychoice' is not 'choose' or its 'firewall_filterid' is
        non-empty.

        NOTE(review): the renumbering appears to assume that the last rule
        currently holds the highest numberoption; otherwise bumping can
        leave a gap in the numbering -- confirm against the firewall page.
        '''
        last_rule = self.rules[-1]
        if (last_rule['firewall_easychoice'] != 'choose' or
                last_rule['firewall_filterid']):
            # The original code raised the undefined name 'FirewallIsFull',
            # which surfaced as NameError instead of the intended exception.
            raise FirewallIsFullError(
                'The last rule is already occupied, no free slot left')

        # Shift every rule at or after the requested position by one to
        # make room. The last rule may be shifted too; it is overwritten
        # right below.
        for rule in self.rules:
            numberoption = int(rule['firewall_numberoption'])
            if numberoption >= number:
                rule['firewall_numberoption'] = str(numberoption + 1)

        last_rule['firewall_numberoption'] = str(number)

        return len(self.rules) - 1


class Firewall:
    '''Reads and writes the rules form of a server's firewall page.

    The server object is expected to provide 'servername', 'typeofserver',
    'go(url)', 'submit()' and a 'g.doc' document object with
    'choose_form', 'form_fields', 'set_input' and 'text_assert'; these are
    the only members used here.
    '''

    def __init__(self, server, fields_num=DEFAULT_FIELDS_NUM):
        '''Remember the server and the number of rule slots on the page.'''
        self.server = server
        self.fields_num = fields_num

    def _open_rules_form(self):
        '''Load the firewall page and select its form named 'rules_form'.'''
        url = FIREWALL_URL.format(name=self.server.servername,
                                  typeofserver=self.server.typeofserver)
        self.server.go(url)
        self.server.g.doc.choose_form(name='rules_form')

    def _submit_fields(self, fields):
        '''Fill the given fields into the current form and submit it.

        Logs the text of the page's 'errormessage' div and re-raises when
        the page does not contain the expected confirmation text.
        '''
        for key, value in fields.items():
            self.server.g.doc.set_input(key, value)
        self.server.submit()
        expected = ('Your changes have been saved and should take effect'
                    ' shortly')
        try:
            self.server.g.doc.text_assert(expected)
        except Exception:
            logger.error(self.server.g.doc(
                '//div[@class="errormessage"]').text())
            raise

    def fetch_rules(self):
        '''Download the firewall form and return it as FirewallRules.'''
        self._open_rules_form()
        fields = self.server.g.doc.form_fields()
        return self._fields_to_rules(fields)

    def update_rules(self, rules, sanity_check=True, idxs_updated=None):
        '''Upload rules to the firewall page.

        rules -- FirewallRules to upload.
        sanity_check -- run rules.sanity_check() before anything is sent.
        idxs_updated -- optional container of rule indexes; when non-empty
                        only those rules are submitted.

        The form fields found on the page are saved before submitting. If
        the upload fails, the error is logged and the saved fields are
        submitted again to restore the previous state. When that restore
        succeeds, this method returns normally and the upload error is
        only visible in the log. When the restore fails too, the saved
        fields are logged as JSON and the restore error is re-raised.
        '''
        # Convert (and validate) first, so nothing is fetched or sent when
        # the rules are inconsistent.
        fields = self._rules_to_fields(rules, idxs=idxs_updated,
                                       sanity_check=sanity_check)
        self._open_rules_form()
        fields_backup = self.server.g.doc.form_fields()

        try:
            self._submit_fields(fields)
        except Exception:
            logger.exception('Error happened on upload, restoring original'
                             ' rules')
            try:
                self._submit_fields(fields_backup)
            except Exception:
                logger.exception('Error restoring original rules')
                logger.error('Here are rules that we could not restore:')
                logger.error(json.dumps(fields_backup))
                raise
            else:
                logger.info('Original rules restored')

    def _fields_to_rules(self, fields: dict) -> FirewallRules:
        '''Group flat form fields into one dict per rule.

        Fields whose name does not end in '_<integer>' are skipped. The
        resulting FirewallRules uses this firewall's fields_num, so its
        sanity checks match the number of slots created here.
        '''
        rules = [{} for _ in range(self.fields_num)]
        for key, value in fields.items():
            # typical fields look like 'firewall_easychoice_9' with
            # the number in the end as rule number
            field_name, _, suffix = key.rpartition('_')
            try:
                idx = int(suffix)
            except ValueError:
                # the field is not a filter rule, but probably a service
                # field. Skip it.
                continue
            rules[idx][field_name] = value
        # Pass fields_num on; the original dropped it, so a non-default
        # slot count was later validated against DEFAULT_FIELDS_NUM.
        return FirewallRules(rules, self.fields_num)

    def _rules_to_fields(self, rules: FirewallRules,
                         sanity_check: bool, idxs=None) -> dict:
        '''Flatten rules into form fields named '<field>_<rule index>'.

        rules -- FirewallRules to flatten.
        sanity_check -- call rules.sanity_check() first; it raises
                        AssertionError on inconsistent rules.
        idxs -- optional container of rule indexes to include; when it is
                None or empty, all rules are included.
        '''
        if sanity_check:
            rules.sanity_check()
        fields = {}
        for i, rule in enumerate(rules.rules):
            if idxs and i not in idxs:
                continue
            for name, value in rule.items():
                fields['{}_{}'.format(name, i)] = value
        return fields
